from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from distutils.spawn import find_executable
from distutils import sysconfig, log
import setuptools
import setuptools.command.build_py
import setuptools.command.develop
import setuptools.command.build_ext
import setuptools.command.install

from collections import namedtuple
from contextlib import contextmanager
import glob
import os
import shlex
import subprocess
import sys
from textwrap import dedent
import multiprocessing

try:
    import torch
except ImportError as import_error:
    # torch is required by this script (its install directory is passed to
    # cmake), so report the import failure and stop with a non-zero status.
    failure_report = (
        'Unable to import torch. Error:',
        '\t {}'.format(import_error),
        'You need to install pytorch first.',
    )
    for report_line in failure_report:
        print(report_line)
    sys.exit(1)


# Directory layout, anchored at the folder that contains this script.
TOP_DIR = os.path.realpath(os.path.dirname(__file__))
SRC_DIR = os.path.join(TOP_DIR, 'torch_tvm')
CMAKE_BUILD_DIR = os.path.join(TOP_DIR, 'build')

# Accept either the `cmake` or the `cmake3` executable name, preferring the
# former; abort when neither is on PATH.
CMAKE = find_executable('cmake')
if not CMAKE:
    CMAKE = find_executable('cmake3')
if not CMAKE:
    print('Could not find "cmake". Make sure it is in your PATH.')
    sys.exit(1)

# Probe `llvm-config`, `llvm-config-7` and `llvm-config-6.0` in that order and
# keep the first one found. The generator is lazy, so probing stops at the
# first hit.
llvm_config_versions = ['', '-7', '-6.0']
llvm_config_candidates = (
    find_executable('llvm-config{}'.format(ver))
    for ver in llvm_config_versions
)
LLVM_CONFIG = next((found for found in llvm_config_candidates if found), None)
if not LLVM_CONFIG:
    print('Could not find "llvm-config". Make sure it is in your PATH.')
    sys.exit(1)

# Dependency containers handed to setuptools.setup() at the end of this file.
# They start out empty; entries are appended further down the script.
install_requires, setup_requires, tests_require = [], [], []
extras_require = {}

################################################################################
# Global flags and environment variables
################################################################################

# DEBUG is on for ANY non-empty value of the DEBUG environment variable
# (including e.g. "0"); it is off when the variable is unset or empty.
DEBUG = bool(os.environ.get('DEBUG'))

# `--cmake` is a switch private to this script: it forces the cmake configure
# step to run again. Every occurrence is stripped from sys.argv so that
# setuptools never sees the unknown flag.
RERUN_CMAKE = '--cmake' in sys.argv
filtered_args = [arg for arg in sys.argv if arg != '--cmake']
sys.argv = filtered_args

################################################################################
# Version
################################################################################

# Identify the checked-out commit. git_version stays None when git cannot be
# executed or when the command exits with a non-zero status.
try:
    head_output = subprocess.check_output(
        ['git', 'rev-parse', 'HEAD'], cwd=TOP_DIR)
except (OSError, subprocess.CalledProcessError):
    git_version = None
else:
    git_version = head_output.decode('ascii').strip()

# The release version is the stripped content of the VERSION_NUMBER file that
# sits next to this script.
version_number_path = os.path.join(TOP_DIR, 'VERSION_NUMBER')
with open(version_number_path) as version_file:
    release_version = version_file.read().strip()

# NOTE: VersionInfo is an *instance* of a throw-away namedtuple type, not the
# type itself.
VersionInfo = namedtuple('VersionInfo', ['version', 'git_version'])(
    version=release_version,
    git_version=git_version,
)

################################################################################
# Utilities
################################################################################


@contextmanager
def cd(path):
    """Temporarily switch the process working directory to ``path``.

    The directory that was current on entry is restored when the ``with``
    block exits, whether it finishes normally or raises.

    Raises:
        RuntimeError: if ``path`` is not absolute (checked before anything
            is changed).
    """
    if not os.path.isabs(path):
        raise RuntimeError('Can only cd to absolute path, got: {}'.format(path))

    previous_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Always go back, even if the body of the with-block failed.
        os.chdir(previous_dir)


################################################################################
# Customized commands
################################################################################

class build_py(setuptools.command.build_py.build_py):
    """build_py variant that regenerates ``version.py`` before building.

    The generated module lives in SRC_DIR and exposes two names, ``version``
    and ``git_version``, taken from the module-level ``VersionInfo``.
    """

    def run(self):
        """Write ``version.py`` into SRC_DIR, then run the stock build_py."""
        with open(os.path.join(SRC_DIR, 'version.py'), 'w') as f:
            # The values are rendered with `!r` so the generated file always
            # contains valid Python literals. In particular, an unknown git
            # revision is written as None instead of the (truthy) string
            # 'None', and quotes inside a value cannot break the file.
            f.write(dedent('''\
            # This file is generated by setup.py. DO NOT EDIT!

            from __future__ import absolute_import
            from __future__ import division
            from __future__ import print_function
            from __future__ import unicode_literals

            version = {version!r}
            git_version = {git_version!r}
            '''.format(**dict(VersionInfo._asdict()))))
        return setuptools.command.build_py.build_py.run(self)


class cmake_build(setuptools.Command):
    """
    Configures and compiles the native sources with cmake.

    This command is run by the custom `build_ext` command of this script
    (which `python setup.py develop` triggers in turn).

    The configure step only runs when the build directory has no
    `CMakeCache.txt` yet or when `--cmake` was passed to setup.py; the
    compile step runs every time.

    Custom args can be passed to cmake by specifying the `CMAKE_ARGS`
    environment variable.
    """

    description = 'configure and build the native sources with cmake'

    # distutils refuses to run a command named on the command line unless its
    # class provides a `user_options` list, so declare an empty one. Without
    # it `python setup.py cmake_build` fails with DistutilsClassError.
    user_options = []

    def initialize_options(self):
        """Nothing to initialise; present because Command requires it."""
        pass

    def finalize_options(self):
        """Nothing to finalise; present because Command requires it."""
        pass

    def _run_cmake(self):
        """Run the cmake configure step inside CMAKE_BUILD_DIR.

        Raises:
            subprocess.CalledProcessError: if cmake exits with a non-zero
                status.
        """
        with cd(CMAKE_BUILD_DIR):
            cmake_args = [
                CMAKE,
                '-DCMAKE_EXPORT_COMPILE_COMMANDS=ON',
                '-DCMAKE_BUILD_TYPE={}'.format('Debug' if DEBUG else 'Release'),
                '-DPYTHON_EXECUTABLE={}'.format(sys.executable),
                # PyTorch cmake args
                '-DPYTORCH_DIR={}'.format(
                    os.path.dirname(os.path.realpath(torch.__file__))),
                # TVM cmake args
                '-DUSE_LLVM={}'.format(LLVM_CONFIG),
                '-DTVM_DIR={}'.format(os.path.join(TOP_DIR, 'tvm')),
            ]
            if 'CMAKE_ARGS' in os.environ:
                # Shell-style splitting, so quoted values may contain spaces.
                extra_cmake_args = shlex.split(os.environ['CMAKE_ARGS'])
                log.info('Extra cmake args: {}'.format(extra_cmake_args))
                cmake_args.extend(extra_cmake_args)
            # The source directory goes last, after all -D definitions.
            cmake_args.append(TOP_DIR)
            subprocess.check_call(cmake_args)

    def _run_build(self):
        """Compile the configured tree with one job per CPU.

        Raises:
            subprocess.CalledProcessError: if the build exits with a non-zero
                status.
        """
        with cd(CMAKE_BUILD_DIR):
            build_args = [
                CMAKE,
                '--build', os.curdir,
                # Arguments after `--` are forwarded to the native build
                # tool; `-j N` presumably targets make or a compatible tool.
                '--', '-j', str(multiprocessing.cpu_count()),
            ]
            subprocess.check_call(build_args)

    def run(self):
        """Create the build directory if needed, configure, then compile."""
        if not os.path.exists(CMAKE_BUILD_DIR):
            os.makedirs(CMAKE_BUILD_DIR)

        # A missing CMakeCache.txt means cmake has never configured this tree.
        is_initial_build = not os.path.exists(os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt"))
        if is_initial_build or RERUN_CMAKE:
            self._run_cmake()

        self._run_build()


class develop(setuptools.command.develop.develop):
    """develop command that runs `build_ext` before the stock develop step."""

    def run(self):
        """Run the `build_ext` command, then the regular develop logic."""
        stock_develop = setuptools.command.develop.develop
        self.run_command('build_ext')
        stock_develop.run(self)


class build_ext(setuptools.command.build_ext.build_ext):
    """build_ext variant that delegates compilation to `cmake_build`.

    Instead of compiling anything itself, it copies the extension binaries
    from CMAKE_BUILD_DIR into the `torch_tvm` folder under `build_lib`.
    """

    def run(self):
        """Run `cmake_build` first, then the regular build_ext machinery."""
        stock_build_ext = setuptools.command.build_ext.build_ext
        self.run_command('cmake_build')
        stock_build_ext.run(self)

    def build_extensions(self):
        """Copy each extension's binary from the cmake build directory."""
        for extension in self.extensions:
            qualified_name = self.get_ext_fullname(extension.name)
            # Only the file name is kept: the binary is looked up directly
            # inside CMAKE_BUILD_DIR and always lands in `torch_tvm`.
            artifact = os.path.basename(self.get_ext_filename(qualified_name))

            target_dir = os.path.join(
                os.path.realpath(self.build_lib), 'torch_tvm')
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)

            self.copy_file(
                os.path.join(CMAKE_BUILD_DIR, artifact),
                os.path.join(target_dir, artifact),
            )


class install(setuptools.command.install.install):
    """install command that also installs the bundled TVM Python packages.

    After the regular install it copies the cmake-built TVM tree into
    `tvm/build` and runs `setup.py install` for the `tvm/python` and
    `tvm/topi/python` packages with the current interpreter.
    """

    def run(self):
        """Run the stock install, then stage and install TVM.

        Raises:
            subprocess.CalledProcessError: if the copy or one of the nested
                `setup.py install` runs exits with a non-zero status.
        """
        setuptools.command.install.install.run(self)

        tvm_dir = os.path.join(TOP_DIR, 'tvm')
        tvm_build_dir = os.path.join(tvm_dir, 'build')
        if not os.path.isdir(tvm_build_dir):
            os.makedirs(tvm_build_dir)

        # `cp -a <src>/. <dst>` copies the *contents* of the cmake-built tvm
        # directory into tvm/build.
        subprocess.check_call([
            'cp',
            '-a',
            os.path.join(CMAKE_BUILD_DIR, 'tvm', '.'),
            tvm_build_dir,
        ])

        # Argument lists (no shell) keep this working when the interpreter
        # path contains spaces or shell metacharacters.
        python_package_dirs = (
            os.path.join(tvm_dir, 'python'),
            os.path.join(tvm_dir, 'topi', 'python'),
        )
        for package_dir in python_package_dirs:
            subprocess.check_call(
                [sys.executable, 'setup.py', 'install'], cwd=package_dir)


# Maps setup.py command names to the customised command classes defined above.
cmdclass = dict(
    cmake_build=cmake_build,
    build_py=build_py,
    develop=develop,
    build_ext=build_ext,
    install=install,
)

################################################################################
# Extensions
################################################################################

# The extension is declared without sources: nothing is compiled by
# setuptools itself, the declaration only gives the custom build_ext command
# an extension name to resolve to a file name.
torch_tvm_extension = setuptools.Extension(
    sources=[],
    name=str('torch_tvm._torch_tvm'),
)
ext_modules = [torch_tvm_extension]

################################################################################
# Packages
################################################################################

# no need to do fancy stuff so far
packages = setuptools.find_packages()

################################################################################
# Test
################################################################################

# pytest-runner is needed at setup time; pytest and scikit-image only for the
# test run itself.
setup_requires.extend(['pytest-runner'])
tests_require.extend(['pytest', 'scikit-image'])

################################################################################
# Final
################################################################################

# Collect every setup() argument first, grouped by purpose, then make the call.
setup_options = dict(
    # Project metadata
    name="torch_tvm",
    version=VersionInfo.version,
    description="PyTorch + TVM",
    author='bddppq',
    author_email='jbai@fb.com',
    url='https://github.com/pytorch/tvm',
    # What gets built and shipped
    packages=packages,
    include_package_data=True,
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    # Dependencies
    install_requires=install_requires,
    setup_requires=setup_requires,
    tests_require=tests_require,
    extras_require=extras_require,
)

setuptools.setup(**setup_options)
